"""
    hh_init(; b_grid, a_grid, z_grid, eis)

Build the initial guesses `Va`, `Vb` used to start the backward iteration.

Both are arrays of size `(length(z_grid), length(b_grid), length(a_grid))`, constant along the
first axis:
- `Va[z, b, a] = (0.6 + 1.1 b + a)^(-1/eis)`
- `Vb[z, b, a] = (0.5 + b + 1.2 a)^(-1/eis)`
"""
function hh_init(; b_grid, a_grid, z_grid, eis)
    expo = -1 / eis
    nz = size(z_grid, 1)
    # Copy a (b, a) matrix `nz` times along a new leading axis.
    stack_over_z(M) = reshape(M, 1, size(M)...) .* ones(nz, 1, 1)

    Va = stack_over_z((0.6 .+ 1.1 .* b_grid .+ a_grid') .^ expo)
    Vb = stack_over_z((0.5 .+ b_grid .+ 1.2 .* a_grid') .^ expo)
    return Va, Vb
end

"""
    adjustment_costs(; a, a_grid, ra, χ0, χ1, χ2)

Evaluate the adjustment cost `Ψ(a, a_grid)` (first output of `get_Ψ_and_deriv`) at the policy
`a`. `a_grid` is laid along the third axis, so `a` is presumably indexed `[z, b, a]` — confirm
against the caller.

Returns a 1-tuple `(χ,)`.
"""
function adjustment_costs(; a, a_grid, ra, χ0, χ1, χ2)
    a_grid_3d = reshape(a_grid, 1, 1, :)
    χ = first(get_Ψ_and_deriv(a, a_grid_3d, ra, χ0, χ1, χ2))
    return (χ,)
end

"""
    marginal_cost_grid(; a_grid, ra, χ0, χ1, χ2)

Tabulate ∂Ψ/∂a' (second output of `get_Ψ_and_deriv`) on the full grid.
`Ψ1[i, j]` is evaluated at `a' = a_grid[i]`, `a = a_grid[j]`.

Returns a 1-tuple `(Ψ1,)`.
"""
function marginal_cost_grid(; a_grid, ra, χ0, χ1, χ2)
    next_assets = a_grid     # a' varies down the rows
    curr_assets = a_grid'    # a varies across the columns
    _, Ψ1, _ = get_Ψ_and_deriv(next_assets, curr_assets, ra, χ0, χ1, χ2)
    return (Ψ1,)
end

"""
    _hh(; Va_p, Vb_p, a_grid, b_grid, z_grid, e_grid, k_grid, β, eis, rb, ra, χ0, χ1, χ2, Ψ1)

Perform one backward-iteration step of the two-asset household problem.

Given `Va_p`, `Vb_p`, compute today's asset policies `a`, `b`, consumption `c`, and the
updated derivatives `Va`, `Vb`. The arrays are indexed `[z, b, a]`, which is inferred from the
`addouter(z, b, a)` / `reshape(e_grid, :, 1, 1)` patterns below — TODO confirm.

`Ψ1` is expected to be the `(a', a)` table of ∂Ψ/∂a' produced by `marginal_cost_grid`, which is
registered as a hetinput of `hh_twoasset`.

Returns `(Va, Vb, a, b, c, uce)`, where `uce = e_grid .* c .^ (-1/eis)` is broadcast along the
first axis.

NOTE(review): `Va_p` and `Vb_p` are only multiplied by `β` here. Any expectation over `z` must
already have been applied by the caller (presumably the `het` framework via
`exogenous = "Pi"`) — confirm.
NOTE(review): the names in the `return` line may be read by `het` to wire outputs; keep them.
"""
function _hh(; Va_p, Vb_p, a_grid, b_grid, z_grid, e_grid, k_grid, β, eis, rb, ra, χ0, χ1, χ2, Ψ1)
    # Discounted continuation derivatives and their ratio.
    Wb = β .* Vb_p
    Wa = β .* Va_p
    W_ratio = Wa ./ Wb
    
    # Unconstrained optimality: along the last axis (a' grid, the rows of Ψ1) find by linear
    # interpolation where W_ratio == 1 + Ψ1(a', a), separately for every column a of Ψ1.
    i, π = lhs_equals_rhs_interpolate(W_ratio, 1 .+ Ψ1)

    # Interpolate a' and Wb at that solution, then invert marginal utility: c = Wb^(-eis)
    # (the inverse of uc = c^(-1/eis) computed at the end of this function).
    a_endo_unc = apply_coord(i, π, a_grid)
    c_endo_unc = apply_coord(i, π, Wb)
    # Negative interpolated values are zeroed before the power. This is presumably to avoid a
    # DomainError from a negative base with a non-integer exponent; for eis > 0, 0^(-eis) is
    # Inf. NOTE(review): negatives presumably arise only from extrapolated weights — confirm.
    c_endo_unc[c_endo_unc .< 0] .= zero(eltype(c_endo_unc))
    c_endo_unc = c_endo_unc .^ (-eis)

    # Budget constraint solved for today's b on the endogenous grid:
    # b = (c + a' + b' - z - (1 + ra) a + Ψ(a', a)) / (1 + rb).
    b_endo = (c_endo_unc .+ a_endo_unc .+ addouter(-z_grid, b_grid, -(1 .+ ra) .* a_grid) .+ get_Ψ_and_deriv(a_endo_unc, reshape(a_grid, 1, 1, :), ra, χ0, χ1, χ2)[1]) ./ (1 .+ rb)

    # Invert the endogenous b' -> b mapping onto b_grid to get b'(z, b, a) and a'(z, b, a).
    # Axes 2 and 3 are swapped so that b is last, and swapped back afterwards.
    # NOTE(review): interpolate_coord/apply_coord presumably act on the last axis (not visible).
    i, π = interpolate_coord(permutedims(b_endo, (1, 3, 2)), b_grid)
    a_unc = permutedims(apply_coord(i, π, permutedims(a_endo_unc, (1, 3, 2))), (1, 3, 2))
    b_unc = permutedims(apply_coord(i, π, b_grid), (1, 3, 2))

    # Constrained case: b' is pinned at b_grid[1]. For every κ in k_grid solve
    # W_ratio(b' = b_grid[1]) / (1 + κ) == 1 + Ψ1(a', a). Here κ is presumably the multiplier
    # on the b' >= b_grid[1] constraint — confirm.
    lhs_con = W_ratio[:, 1:1, :] ./ (1 .+ reshape(k_grid, 1, :, 1))
    i, π = lhs_equals_rhs_interpolate(lhs_con, 1 .+ Ψ1)

    # a' and c at the constrained solution: c = ((1 + κ) Wb)^(-eis), written as a product of
    # powers. Negatives are zeroed for the same reason as in the unconstrained case.
    a_endo_con = apply_coord(i, π, a_grid)
    c_endo_con = apply_coord(i, π, Wb[:, 1:1, :])
    c_endo_con[c_endo_con .< 0] .= zero(eltype(c_endo_con))
    c_endo_con = (1 .+ reshape(k_grid, 1, :, 1)).^(-eis) .* c_endo_con .^ (-eis)

    # Same budget constraint with b' = b_grid[1] for every κ; this gives b(z, κ, a).
    b_endo = (c_endo_con .+ a_endo_con .+ addouter(-z_grid, fill(b_grid[1], size(k_grid, 1)), -(1 .+ ra) .* a_grid) .+ get_Ψ_and_deriv(a_endo_con, reshape(a_grid, 1, 1, :), ra, χ0, χ1, χ2)[1]) ./ (1 .+ rb)

    # Re-index the constrained a' from κ to b_grid via interpolate_y.
    # The κ axis is moved last for the call, then axes are swapped back to [z, b, a].
    a_con = permutedims(interpolate_y(permutedims(b_endo, (1, 3, 2)), b_grid, permutedims(a_endo_con, (1, 3, 2))), (1, 3, 2))

    # Combine the two solutions. Wherever the unconstrained b' <= b_grid[1], clamp b' to
    # b_grid[1] and take a' from the constrained solution. The mask is re-evaluated after the
    # clamp; it selects the same entries, because the clamped values equal b_grid[1].
    a, b = copy(a_unc), copy(b_unc)
    b[b .<= b_grid[1]] .= b_grid[1]
    a[b .<= b_grid[1]] .= a_con[b .<= b_grid[1]]

    # Adjustment cost Ψ(a', a) at the chosen policy and its derivative Ψ2 with respect to the
    # second argument (current a); see get_Ψ_and_deriv.
    Ψ, _, Ψ2, = get_Ψ_and_deriv(a, reshape(a_grid, 1, 1, :), ra, χ0, χ1, χ2)

    # Budget constraint: c = z + (1 + rb) b + (1 + ra) a - Ψ - a' - b'.
    c = addouter(z_grid, (1 .+ rb) .* b_grid, (1 .+ ra) .* a_grid) .- Ψ .- a .- b
    uc = c .^ (-1 / eis)
    uce = reshape(e_grid, :, 1, 1) .* uc

    # Today's derivatives: ∂c/∂a = 1 + ra - Ψ2 and ∂c/∂b = 1 + rb, each times marginal utility.
    Va = (1 .+ ra .- Ψ2) .* uc
    Vb = (1 .+ rb) .* uc

    return Va, Vb, a, b, c, uce
end

# Two-asset household block: wraps the backward step `_hh` with `het`.
# `het` is defined elsewhere. Its keywords are presumably: the exogenous transition ("Pi"),
# the policy variables (b, a), the backward-iterated derivatives (Vb, Va), the pre- and
# post-step hooks, and the initial guess — confirm against `het`'s definition.
hh_twoasset = het(
    _hh;
    exogenous = "Pi",
    policy = ["b", "a"],
    backward = ["Vb", "Va"],
    hetinputs = [marginal_cost_grid],
    hetoutputs = [adjustment_costs],
    backward_init = hh_init,
)

"""
    get_Ψ_and_deriv(ap, a, ra, χ0, χ1, χ2)

Compute the asset adjustment cost and its two partial derivatives, elementwise with
broadcasting over `ap` and `a`.

With `Δ = ap - (1 + ra) a` and `D = (1 + ra) a + χ0`:
- `Ψ  = χ1/χ2 · |Δ| · (|Δ|/D)^(χ2 - 1)`
- `Ψ1 = ∂Ψ/∂ap = χ1 · sign(Δ) · (|Δ|/D)^(χ2 - 1)`
- `Ψ2 = ∂Ψ/∂a  = -(1 + ra) · (Ψ1 + (χ2 - 1) Ψ / D)`

Returns `(Ψ, Ψ1, Ψ2)`.
"""
function get_Ψ_and_deriv(ap, a, ra, χ0, χ1, χ2)
    grown = @. (1 + ra) * a
    Δ = @. ap - grown
    absΔ = @. abs(Δ)
    D = @. grown + χ0
    # Shared factor (|Δ|/D)^(χ2 - 1) that appears in both Ψ and Ψ1.
    core = @. (absΔ / D)^(χ2 - 1)

    Ψ = @. χ1 / χ2 * absΔ * core
    Ψ1 = @. χ1 * sign(Δ) * core
    Ψ2 = @. -(1 + ra) * (Ψ1 + (χ2 - 1) * Ψ / D)
    return Ψ, Ψ1, Ψ2
end

"""
    matrix_times_first_dim(A, X)

Apply the matrix `A` along the first axis of `X`, i.e. `Y[:, rest...] = A * X[:, rest...]` for
every trailing index `rest`.

The result has size `(size(A, 1), size(X)[2:end]...)`, so `A` does not need to be square.
For square `A` this is the same size as `X`. A vector `X` yields a vector.
A `DimensionMismatch` is thrown (by `*`) when `size(A, 2) != size(X, 1)`.
"""
function matrix_times_first_dim(A, X)
    # Flatten all trailing axes so a single matrix product handles every slice.
    Xmat = reshape(X, size(X, 1), :)
    # Rebuild the trailing axes. The leading extent is A's row count, not size(X, 1),
    # so non-square A works.
    return reshape(A * Xmat, size(A, 1), Base.tail(size(X))...)
end

"""
    addouter(z, b, a)

Form the outer sum `out[i, j, k] = z[i] + b[j] + a[k]`.
Each input is flattened and laid along its own axis (1, 2 and 3 respectively), so the result
always has three dimensions.
"""
function addouter(z, b, a)
    along1 = reshape(z, :, 1, 1)
    along2 = reshape(b, 1, :, 1)
    along3 = reshape(a, 1, 1, :)
    return @. along1 + along2 + along3
end

"""
    lhs_equals_rhs_interpolate_internal(lhs, rhs)

For each column `j` of the `ni × nj` array `rhs`, find where `err[i] = rhs[i, j] - lhs[i]`
changes sign along `i`. Return the lower bracketing row and the linear-interpolation weight
on it.

The row index is carried over from one column to the next and never rewound. The scan is
therefore linear overall, but it only finds the right bracket if the crossing row is
non-decreasing in `j` (presumably guaranteed by the caller's monotone `rhs`; TODO confirm).

# Returns
- `iout::Vector{Int}` (length `nj`): lower bracketing row index.
- `piout::Vector{Float64}` (length `nj`): weight on row `iout[j]`, chosen so that
  `piout[j] * err[iout[j]] + (1 - piout[j]) * err[iout[j] + 1] == 0`.
  If the scan is still at row 1 (`lhs[1] < rhs[1, j]`, or `ni == 1`), the result is `(1, 1.0)`.
  If no sign change is found, the last bracket `(ni - 1, ni)` is used and the weight may fall
  outside `[0, 1]`.

# Throws
- `DimensionMismatch` if `length(lhs) != size(rhs, 1)`.
"""
function lhs_equals_rhs_interpolate_internal(lhs, rhs)
    ni, nj = size(rhs)
    length(lhs) == ni || throw(DimensionMismatch(
        "lhs has length $(length(lhs)) but rhs has $ni rows"))

    iout = zeros(Int, nj)
    piout = zeros(nj)

    i = 1
    for j in 1:nj
        # Advance until lhs drops below rhs in column j, or until the last row is reached.
        # The bound is `ni`, because `i` indexes the rows of `rhs` and the entries of `lhs`.
        # The previous bound `nj` was only correct for square `rhs`.
        # `!(a < b)` (rather than `a >= b`) keeps the original NaN behaviour.
        while !(lhs[i] < rhs[i, j]) && i < ni
            i += 1
        end

        if i == 1
            iout[j] = 1
            piout[j] = 1
        else
            iout[j] = i - 1
            err_upper = rhs[i, j] - lhs[i]
            err_lower = rhs[i - 1, j] - lhs[i - 1]
            piout[j] = err_upper / (err_upper - err_lower)
        end
    end

    return iout, piout
end

"""
    lhs_equals_rhs_interpolate(lhs, rhs)

Batched wrapper around `lhs_equals_rhs_interpolate_internal`. For every leading index of `lhs`,
interpolate along the last axis of `lhs` (length `ni`) to solve `lhs == rhs[:, j]` for each
column `j` of `rhs`.

# Arguments
- `lhs`: rank-1, -2 or -3 array whose last axis has length `ni`.
- `rhs`: an `ni × nj` matrix, or an `m × ni × nj` array. In the rank-3 case, the leading axis
  of `rhs` lines up with the second axis of a rank-3 `lhs` (or the first axis of a rank-2
  `lhs`). Size-1 extents on that shared axis are broadcast.

# Returns
`(iout, piout)`: lower bracketing indices (`Int`) and the interpolation weights on them. Their
shape is the (broadcast) leading axes of `lhs` followed by `nj`. Only the axes this function
inserts internally are dropped again, so the output rank depends on the input ranks alone and
never on whether some grid happens to have length 1. For example:
- `ni` vector with `ni × nj` matrix gives length-`nj` vectors.
- `(nz, nb, ni)` array with `ni × nj` matrix gives `(nz, nb, nj)` arrays, even when `nz == 1`.

# Throws
- `ArgumentError` for unsupported ranks.
- `DimensionMismatch` when the shared axis cannot be broadcast.
"""
function lhs_equals_rhs_interpolate(lhs, rhs)
    lhs_rank, rhs_rank = ndims(lhs), ndims(rhs)
    1 <= lhs_rank <= 3 ||
        throw(ArgumentError("lhs must have 1, 2 or 3 dimensions, got $lhs_rank"))
    2 <= rhs_rank <= 3 ||
        throw(ArgumentError("rhs must have 2 or 3 dimensions, got $rhs_rank"))
    nj = size(rhs, rhs_rank)

    # Promote both inputs to rank 3 without copying: lhs -> (n1, n2, ni), rhs -> (m, ni, nj).
    lhs3 = lhs_rank == 1 ? reshape(lhs, 1, 1, :) :
           lhs_rank == 2 ? reshape(lhs, 1, size(lhs)...) : lhs
    rhs3 = rhs_rank == 2 ? reshape(rhs, 1, size(rhs)...) : rhs

    n1 = size(lhs3, 1)
    nl2, nr1 = size(lhs3, 2), size(rhs3, 1)
    (nl2 == nr1 || nl2 == 1 || nr1 == 1) || throw(DimensionMismatch(
        "lhs axis of length $nl2 cannot be broadcast against rhs leading axis of length $nr1"))
    n2 = max(nl2, nr1)

    iout = zeros(Int, n1, n2, nj)
    piout = zeros(n1, n2, nj)

    for d2 in 1:n2
        # Broadcast size-1 extents by indexing slot 1, instead of materialising `repeat` copies.
        rhs_slice = view(rhs3, nr1 == 1 ? 1 : d2, :, :)
        l2 = nl2 == 1 ? 1 : d2
        for d1 in 1:n1
            i_j, pi_j = lhs_equals_rhs_interpolate_internal(view(lhs3, d1, l2, :), rhs_slice)
            iout[d1, d2, :] = i_j
            piout[d1, d2, :] = pi_j
        end
    end

    # Drop only the axes inserted above. Axis 1 is artificial unless lhs was rank 3.
    # Axis 2 is artificial only when lhs was a vector and rhs a matrix.
    if lhs_rank == 3
        return iout, piout
    elseif lhs_rank == 2 || rhs_rank == 3
        return iout[1, :, :], piout[1, :, :]
    else
        return iout[1, 1, :], piout[1, 1, :]
    end
end
